{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
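    "# RNN-循环神经网络\n",
    "\n",
    "用 TensorFlow 的 `BasicRNNCell` + `dynamic_rnn` 学习一个人工构造的二分类序列任务：输入 $X$ 是随机 0/1 序列；$Y_t = 1$ 的概率基础为 0.5，若 $X_{t-3} = 1$ 则加 0.5，若 $X_{t-8} = 1$ 则减 0.25。"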
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python import debug as tf_debug\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_steps = 10        # truncated-BPTT window: time steps per batch\n",
    "batch_size = 200      # number of parallel sub-sequences (rows) per batch\n",
    "num_classes = 2       # binary inputs/labels: one-hot depth and softmax width\n",
    "state_size = 16       # RNN hidden-state width\n",
    "learning_rate = 0.1   # Adagrad learning rate\n",
    "\n",
    "# Seed NumPy so gen_data() is reproducible under Restart & Run All.\n",
    "# (TF weight initialisation remains unseeded.)\n",
    "SEED = 42\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(size=1000000):\n",
    "    \"\"\"Generate the synthetic binary sequence task (1,000,000 samples by default).\n",
    "\n",
    "    X is a uniform random 0/1 sequence of length `size`. For each position t,\n",
    "    P(Y[t] = 1) starts at 0.5, is raised by 0.5 when X[t-3] == 1 and lowered by\n",
    "    0.25 when X[t-8] == 1 (the rule from the accompanying article). For the\n",
    "    first few positions t-3 / t-8 wrap around to the end of the sequence,\n",
    "    matching a negative-index lookup.\n",
    "\n",
    "    Returns:\n",
    "        (X, Y): two integer arrays of shape (size,).\n",
    "    \"\"\"\n",
    "    X = np.array(np.random.choice(2, size=(size,)))\n",
    "    # np.roll(X, k)[t] == X[t - k] (with wrap-around): vectorised form of the\n",
    "    # per-element lookup, which also works for size < 8.\n",
    "    threshold = 0.5 + 0.5 * (np.roll(X, 3) == 1) - 0.25 * (np.roll(X, 8) == 1)\n",
    "    # One uniform draw per position, in order; Y[t] = 1 when the draw does not\n",
    "    # exceed the threshold, i.e. with probability threshold[t].\n",
    "    Y = (np.random.rand(size) <= threshold).astype(int)\n",
    "    return X, Y\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成batch数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_batch(raw_data, batch_size, num_step):\n",
    "    \"\"\"Yield (x, y) mini-batches for truncated backpropagation through time.\n",
    "\n",
    "    The raw sequences are cut into `batch_size` contiguous rows (a trailing\n",
    "    remainder that does not fill a whole row is dropped). Consecutive batches\n",
    "    are consecutive `num_step`-wide column windows of those rows, so the RNN\n",
    "    state at the end of one batch is a valid initial state for the next.\n",
    "\n",
    "    Args:\n",
    "        raw_data: pair (raw_x, raw_y) of equal-length 1-D integer sequences.\n",
    "        batch_size: number of rows per batch.\n",
    "        num_step: number of time steps (columns) per batch.\n",
    "\n",
    "    Yields:\n",
    "        (x, y): int32 arrays of shape (batch_size, num_step).\n",
    "    \"\"\"\n",
    "    raw_x, raw_y = raw_data\n",
    "    batch_partition_length = len(raw_x) // batch_size   # e.g. 1000000 // 200 = 5000\n",
    "    usable = batch_size * batch_partition_length\n",
    "    # Row i holds raw[i*L:(i+1)*L] with L = batch_partition_length.\n",
    "    data_x = np.array(raw_x[:usable], dtype=np.int32).reshape(batch_size, batch_partition_length)\n",
    "    data_y = np.array(raw_y[:usable], dtype=np.int32).reshape(batch_size, batch_partition_length)\n",
    "    # The window width comes from the num_step argument, not a global.\n",
    "    epoch_size = batch_partition_length // num_step      # e.g. 5000 // 10 = 500 batches\n",
    "    for i in range(epoch_size):\n",
    "        window = slice(i * num_step, (i + 1) * num_step)\n",
    "        yield data_x[:, window], data_y[:, window]\n",
    "\n",
    "\n",
    "def gen_epochs(n, num_steps):\n",
    "    \"\"\"Yield `n` epoch generators, each over a freshly generated dataset.\n",
    "\n",
    "    Relies on the notebook-level `batch_size` and `gen_data()` from earlier cells.\n",
    "    \"\"\"\n",
    "    for _ in range(n):\n",
    "        yield gen_batch(gen_data(), batch_size, num_steps)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()  # start from an empty graph so the cells below can be re-run\n",
    "\n",
    "# Integer symbols in, integer labels out: one row per parallel sub-sequence.\n",
    "x = tf.placeholder(tf.int32, shape=[batch_size, num_steps], name='x')\n",
    "y = tf.placeholder(tf.int32, shape=[batch_size, num_steps], name='y')\n",
    "# Zero initial RNN state. It is an ordinary tensor, but train_rnn feeds it with\n",
    "# the previous batch's final state to carry state across batches.\n",
    "init_state = tf.zeros([batch_size, state_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the 0/1 inputs: (batch_size, num_steps) -> (batch_size, num_steps, num_classes).\n",
    "rnn_inputs = tf.one_hot(indices=x, depth=num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义RNN cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n",
      "Tensor(\"rnn/transpose_1:0\", shape=(200, 10, 16), dtype=float32)\n",
      "Tensor(\"rnn/while/Exit_3:0\", shape=(200, 16), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Single-layer basic RNN cell with a state of width state_size.\n",
    "cell = tf.contrib.rnn.BasicRNNCell(num_units=state_size)\n",
    "\n",
    "# Inputs are batch-major, so the printed shapes below are\n",
    "# (batch_size, num_steps, state_size) for the outputs and (batch_size, state_size) for the state.\n",
    "rnn_outputs, final_state = tf.nn.dynamic_rnn(\n",
    "    cell,\n",
    "    rnn_inputs,\n",
    "    initial_state=init_state,\n",
    ")\n",
    "print(rnn_outputs)\n",
    "print(final_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测，损失，优化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output projection from the RNN state to class logits (applied to every time step).\n",
    "with tf.variable_scope('softmax'):\n",
    "    W = tf.get_variable('W', shape=[state_size, num_classes])\n",
    "    b = tf.get_variable('b', shape=[num_classes],\n",
    "                        initializer=tf.zeros_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the same projection to every time step: flatten (batch, time) into rows,\n",
    "# project, then restore the (batch_size, num_steps, num_classes) shape.\n",
    "flat_outputs = tf.reshape(rnn_outputs, [-1, state_size])\n",
    "logits = tf.reshape(tf.matmul(flat_outputs, W) + b,\n",
    "                    shape=[batch_size, num_steps, num_classes])\n",
    "predictions = tf.nn.softmax(logits)  # per-step class probabilities (not used for training)\n",
    "\n",
    "# Integer labels y are used directly; losses has one entry per (example, time step).\n",
    "losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)\n",
    "total_loss = tf.reduce_mean(losses)\n",
    "train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "epoch 0\n",
      "第 100 步的平均损失 0.5054965761303901\n",
      "第 200 步的平均损失 0.47457211136817934\n",
      "第 300 步的平均损失 0.47125625878572464\n",
      "第 400 步的平均损失 0.4684878721833229\n",
      "\n",
      "epoch 1\n",
      "第 100 步的平均损失 0.47112494468688965\n",
      "第 200 步的平均损失 0.4625227051973343\n",
      "第 300 步的平均损失 0.46225157380104065\n",
      "第 400 步的平均损失 0.4620427623391151\n",
      "\n",
      "epoch 2\n",
      "第 100 步的平均损失 0.4671410319209099\n",
      "第 200 步的平均损失 0.46128752201795575\n",
      "第 300 步的平均损失 0.4601376268267632\n",
      "第 400 步的平均损失 0.4612331798672676\n",
      "\n",
      "epoch 3\n",
      "第 100 步的平均损失 0.4680284786224365\n",
      "第 200 步的平均损失 0.4586708062887192\n",
      "第 300 步的平均损失 0.45836528539657595\n",
      "第 400 步的平均损失 0.45970610201358797\n",
      "\n",
      "epoch 4\n",
      "第 100 步的平均损失 0.4671090090274811\n",
      "第 200 步的平均损失 0.45870311588048934\n",
      "第 300 步的平均损失 0.4596740183234215\n",
      "第 400 步的平均损失 0.45912133157253265\n",
      "0.5054965761303901\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl8VPW9//HXJ/sKWVmSAAmLLCqgRhYFpO5boVq1rtWf+sPaurS116u3vbYPvbeLvWp7/WmrVVtrbd1aK0VccMENQQKy7zthScKSAAkh2/f3x0wwhCyTZJJJ5ryfj0cemTnznTmfHIb3nPme7/kec84hIiLeEBHqAkREpOso9EVEPEShLyLiIQp9EREPUeiLiHiIQl9ExEMU+iIiHqLQFxHxEIW+iIiHRIW6gMYyMjJcbm5uqMsQEelRFi1atMc5l9lau24X+rm5uRQUFIS6DBGRHsXMtgbSTt07IiIeotAXEfEQhb6IiIco9EVEPEShLyLiIQp9EREPUeiLiHhI2IT+jtLDPPLuWrbuLQ91KSIi3VbYhH5ZRTWPf7CBlTsPhLoUEZFuK2xCPzs1HoDC/RUhrkREpPsKm9DvHR9NclwUhfsPh7oUEZFuK2xCHyAnNYEdCn0RkWaFWejHa09fRKQFYRX62Snx7Cg9jHMu1KWIiHRLYRX6OanxHDpSQ9nh6lCXIiLSLYVd6APq4hERaUaYhX4CoNAXEWlOmIW+xuqLiLQkrEK/d3w0iTGR7CjVnr6ISFPCKvTNjJzUBHXviIg0I6xCH3zTMSj0RUSaFnahn5Mazw716YuINCksQ/9AZQ0HKjVWX0SksbAL/ewU37BNzcEjInK8sAt9naAlItK8MA599euLiDQWdqGflhhDXHSEundERJoQdqGvsfoiIs0LKPTN7EIzW2tmG8zsviYev8nMSsxsif/n1gaP3Whm6/0/Nwaz+OZkp8RTWKruHRGRxqJaa2BmkcATwHlAIbDQzGY651Y1avqyc+6ORs9NA34K5AMOWOR/7v6gVN+MnNR4lhWWduYqRER6pED29McBG5xzm5xzVcBLwPQAX/8CYI5zbp8/6OcAF7av1MDlpCawv6Ka8iM1nb0qEZEeJZDQzwa2N7hf6F/W2DfNbJmZvWZmA9ryXDObYWYFZlZQUlISYOktFOwfwaOJ10REjhVI6FsTyxpfj/BfQK5zbjTwHvB8G56Lc+5p51y+cy4/MzMzgJJapmGbIiJNCyT0C4EBDe7nADsbNnDO7XXOHfHf/QNwWqDP7Qw5KTpBS0SkKYGE/kJgmJnlmVkMcDUws2EDM+vf4O40YLX/9jvA+WaWamapwPn+ZZ0qIymWmCiN1RcRaazV0TvOuRozuwNfWEcCzznnVprZg0CBc24mcJeZTQNqgH3ATf7n7jOzh/B9cAA86Jzb1wl/xzEiIoycFE2xLCLSWKuhD+Ccmw3MbrTsgQa37wfub+a5zwHPdaDGdslOjadQB3JFRI4Rdmfk1tO8+iIixwvj0E9gz6EqDlfVhroUEZFuI2xDPztFY/VFRBoL29DXWH0RkeOFbehn62IqIiLHCdvQ75McR3SkqXtHRKSBsA39yAgjS2P1RUSOEbahD76DuRq2KSLylbAO/ZxU7emLiDQU1qGfnZJA8cEjVFZrrL6ICIR56NcP29xVVhniSkREugdPhL7G6ouI+IR16B+9gpb69UVEgDAP/X694oiMMB3MFRHxC+vQj4qMoH/vOHXviIj4hXXog3+svs7KFREBPBD6OakJ6t4REfEL+9DPTo1n94FKqmrqQl2KiEjIhX3o56TG4xzs1lh9ERFvhD5orL6ICHgh9FMSAHSRdBERPBD6/XrHEWG6mIqICHgg9GOiIujbS2P1RUTAA6EPvn59TcUgIuKZ0NdYfRER8EjoZ6f4xurX1Gqsvoh4mydCPyc1nto6x+4DGqsvIt7mkdD3D9tU
F4+IeJwnQl/z6ouI+Hgi9LNS4gDt6YuIeCL0Y6Mi6ZMcq7H6IuJ5ngh98I/V11QMIuJxHgp9jdUXEQko9M3sQjNba2YbzOy+FtpdYWbOzPL992PM7I9mttzMlprZ1CDV3WbZqfHsKjtMbZ0LVQkiIiHXauibWSTwBHARMAq4xsxGNdEuGbgLWNBg8f8FcM6dDJwHPGJmIfl2kZMaT3Wto/igxuqLiHcFEsDjgA3OuU3OuSrgJWB6E+0eAh4GGqbqKOB9AOdcMVAK5Heo4nbKTqmfV19dPCLiXYGEfjawvcH9Qv+yo8zsFGCAc25Wo+cuBaabWZSZ5QGnAQMar8DMZphZgZkVlJSUtOkPCFT9CVoaqy8iXhYVQBtrYtnRjnF/d81jwE1NtHsOGAkUAFuBeUDNcS/m3NPA0wD5+fmd0umuK2iJiAQW+oUcu3eeA+xscD8ZOAmYa2YA/YCZZjbNOVcA/KC+oZnNA9Z3tOj2iIuOJCMpRsM2RcTTAuneWQgMM7M8M4sBrgZm1j/onCtzzmU453Kdc7nAfGCac67AzBLMLBHAzM4Dapxzq4L/ZwQmW8M2RcTjWt3Td87VmNkdwDtAJPCcc26lmT0IFDjnZrbw9D7AO2ZWB+wAbghG0e2VkxrPqp0HQlmCiEhIBdK9g3NuNjC70bIHmmk7tcHtLcDw9pcXXDkp8cxZVURdnSMioqlDFSIi4c0zZ+SCb0+/qqaOPYeOhLoUEZGQ8FTo10+xvF39+iLiUZ4K/aNj9TWCR0Q8ylOh/9VZuRqrLyLe5KnQT4yNIjUhWmfliohneSr0QVMsi4i3eS70s1Pi1b0jIp7ludCvv4KWc5pXX0S8x5OhX1ldx97yqlCXIiLS5TwX+tmaYllEPMxzof/VFMsKfRHxHs+Ffrbm1RcRD/Nc6PeKi6ZXXJTOyhURT/Jc6IPG6ouId3ky9LNTNVZfRLzJk6GfkxrPjv0aqy8i3uPR0E+gvKqW0orqUJciItKlPBn69bNt6mCuiHiNJ0M/R8M2RcSjPB762tMXEW/xZOj3jo8mKTZKoS8inuPJ0DczclLjFfoi4jmeDH3wHczVgVwR8RrPhn6OTtASEQ/ybOhnp8ZzsLKGssMaqy8i3uHZ0M/RvPoi4kEeDn2N1RcR7/Fs6NeflasRPCLiJZ4N/bTEGOKjIzWCR0Q8xbOh/9VYfXXviIh3eDb0wTeCR3v6IuIlng59nZUrIl7j6dDPTkmgtKKaQ0dqQl2KiEiXCCj0zexCM1trZhvM7L4W2l1hZs7M8v33o83seTNbbmarzez+YBUeDPXDNjVWX0S8otXQN7NI4AngImAUcI2ZjWqiXTJwF7CgweIrgVjn3MnAacBtZpbb8bKDQ2P1RcRrAtnTHwdscM5tcs5VAS8B05to9xDwMFDZYJkDEs0sCogHqoADHSs5eLJTdQUtEfGWQEI/G9je4H6hf9lRZnYKMMA5N6vRc18DyoFdwDbgf5xz+xqvwMxmmFmBmRWUlJS0pf4OyUyKJTYqQgdzRcQzAgl9a2KZO/qgWQTwGHBPE+3GAbVAFpAH3GNmg497Meeeds7lO+fyMzMzAyo8GMyM7BSN1RcR74gKoE0hMKDB/RxgZ4P7ycBJwFwzA+gHzDSzacC1wNvOuWqg2Mw+A/KBTUGoPSiyU+N1IFdEPCOQPf2FwDAzyzOzGOBqYGb9g865MudchnMu1zmXC8wHpjnnCvB16ZxtPonABGBN0P+KDshJTVD3joh4Rquh75yrAe4A3gFWA68451aa2YP+vfmWPAEkASvwfXj80Tm3rIM1B1VOajx7y6uoqNJYfREJf4F07+Ccmw3MbrTsgWbaTm1w+xC+YZvdVv2wzZ2lhxnaJznE1YiIdC5Pn5ELX4X+dnXxiIgHeD70s1N0BS0R8Q7Ph36f5FiiI00Hc0XEEzwf+hERRpbG6ouIR3g+9MHXr6+pGETECxT6QE6KxuqLiDco9PGdlVty8AiV1bWhLkVEpFMp
9Dl2rL6ISDhT6APZKfXz6iv0RSS8KfSBnDT/WH3t6YtImFPoA32TY4mKMA3bFJGwp9AHoiIj6Nc7Tt07IhL2FPp+OZpXX0Q8QKHvp3n1RcQLFPp+2SnxFB2spKqmLtSliIh0GoW+X05qPM7BrjLt7YtI+FLo+2X7T9BavG1/iCsREek8Cn2/UwakMrxvMve+toy/LyoMdTkiIp1Coe8XHxPJq7dPZFxeGve8upTfvLcO51yoyxIRCSqFfgO94qL5403juOK0HH7z3nrueXWpDuyKSFgJ6MLoXhITFcGvrxjNwLQEHp2zjl2llfz+htPoHR8d6tJERDpMe/pNMDPuOmcYj141hoKt+7jid/PYvk9TNIhIz6fQb8Hlp+bw/M3j2H2gksuenMeywtJQlyQi0iEK/VacMSSDf9x+BrFREXzrqfm8t6oo1CWJiLSbQj8Aw/om8/r3zmBY3yRmvFDA8/O2hLokEZF2UegHqE9yHC/NmMA5I/vy05kreWjWKmrrNKRTRHoWhX4bJMRE8fvrT+OmM3J59tPNfPfFRRyu0nV1RaTnUOi3UWSE8bNpJ/LApaN4d1UR1/xhPnsOHQl1WSIiAVHot9PNk/L4/fWnsWb3AS578jM2lhwKdUkiIq1S6HfABSf246UZEzlcVcvlT85jwaa9oS5JRKRFCv0OGjsghde/eybpSTHc8OwXvDB/a0jm7KmrcyzYtJfqWk0bISLNU+gHwYC0BP5x+xlMHJLOf/5zBbe9sIj95VVdtv6dpYe54bkFfOvp+Tz10cYuW6+I9DwBhb6ZXWhma81sg5nd10K7K8zMmVm+//51ZrakwU+dmY0NVvHdSUpCDH+86XR+cslIPlxbzEW//YTPN3Zud49zjte/LOSC33zMl9tKGZyZyPOfb+VIjUYUiUjTWg19M4sEngAuAkYB15jZqCbaJQN3AQvqlznnXnTOjXXOjQVuALY455YEq/juJiLCuHXyYF7/7pnEx0Ry7TPzeeTdtdR0QpfLvvIqvvfXxfzg5aUM75vM23dP4WdfP5GSg0f419JdQV+fiISHQPb0xwEbnHObnHNVwEvA9CbaPQQ8DFQ28zrXAH9rV5U9zEnZvZl15ySuODWHxz/YwFVPfR7UCds+WFPEBb/5mDmrivj3C0fw8m0TGZiewORhGZzQN4lnP92sawGISJMCCf1sYHuD+4X+ZUeZ2SnAAOfcrBZe51t4JPQBEmOj+PWVY/jfa05hfdEhLv7fT/jX0p0des3yIzXc/4/l3PynAtITY3jje5O4feoQIiMM8M0OeuukwazedaDTu5ZEpGcKJPStiWVHdyPNLAJ4DLin2RcwGw9UOOdWNPP4DDMrMLOCkpKSAErqOaaNyWL23ZMZ2ieJO//2Jfe+tpSKqpo2v07Bln1c9NtPeGnhNm47azBv3HEmo7J6Hb++sVlkJMXwzKebg1G+iISZQEK/EBjQ4H4O0HCXNRk4CZhrZluACcDM+oO5flfTwl6+c+5p51y+cy4/MzMz0Np7jAFpCbxy20Tu+NpQXl1UyKWPf8qKHWUBPbeqpo5fvb2Gq576HIfj5RkTuf+ikcRGRTbZPi46kusnDOKDNcU6YUxEjhNI6C8EhplZnpnF4AvwmfUPOufKnHMZzrlc51wuMB+Y5pwrgKPfBK7EdyzAs6IjI/jRBcN58dbxlB+p4fIn57Xa975m9wGmP/EZv5u7kavyB/DW3VMYl5fW6rqunzCImKgIntPevog00mroO+dqgDuAd4DVwCvOuZVm9qCZTQtgHVOAQufcpo6VGh7OGJLBW3dPYcoJmTw0axU3/2nhcXP31NY5nvpoI9Me/4ySg5U88+18fvnN0STFBnZ1y4ykWC4bm83fFxd26fkCItL9WXcb5ZGfn+8KCgpCXUanc87xwvyt/Nebq+kVF81j3xrD5GGZbN9XwT2vLOWLLfu44MS+/Pyyk0lPim3z668rOsj5j33Mv10wnO99bWgn/AUi0p2Y2SLn
XH5r7XRh9BAxM749MZfTc9O4629fcsOzXzBtTBbvry4iwoxHrhzD5admY9bUcfTWndA3mSknZPKneVu4dXJes8cARMRbNA1DiI3s34uZd0ziuvEDmbl0J6NzUnjr+5P55mk57Q78erdMyqPk4BFm6WQtEfFT9043sn1fBdkp8UREdCzs6znnOP+xj4mOjODNuyZ1+ENEurcl20vJTU8gJSEm1KVICATavaM9/W5kQFpC0AIf/CdrTc5j1a4DfK5pn8Pa2t0HufzJz3hw1qpQlyLdnEI/zE0fm016YgzPfqLhm+Hsv2evps7BrGW7KK3QiC1pnkI/zNWfrPX+mmI26WStsDR3bTEfryvhytNyqKqp47VFhaEuSboxhb4HHD1Z67Pw3duvrq3jt++tZ0PxwVCX0qVqauv4+ezV5KYn8N+Xncxpg1J5ccE2T064N2dVEeuLvPXv3x4KfQ/ITI7lG2OzeG1R+J6s9T/vruWx99bxvRe/pKrGO1cPe7lgO+uKDnHfRSOIiYrguvED2bynnHkem3Bv855yvvOXRfzo1aWe/MBrC4W+R9wyaTCV1XX89YttoS4l6D5cU8xTH20if1Aqa4sO8sSHG0JdUpc4WFnNo++uY1xuGhec2A+Ai0/uT0pCNC8u2Bri6rrWb95bR22dY2lhGV9s3hfqcro1hb5HDO+XzORhGTw/b0tY7QnvLD3MD19Zwsj+vfjLreOZPjaLJz7cwJrdB0JdWqd7cu5G9pZX8ZNLRx4djhsXHcmVp+Xw7soiig80d2mL8LJ290FmLt3JTWfkkpYYwx8+0YwvLVHoe8gtk/IoPniEWcs6Nq9/d1FdW8edf/N15zx53anERUfy06+fSO/4aO59bVmnXLGsu9i+r4JnP93M5adkMzon5ZjHrh0/iJo6x8sLtzfz7PDy2Jx1JMVE8f1zh3HDhEG8t7rYc8d22kKh7yFnnZDJsD5JPPNJeFxZ65F317Fo635+8c3R5GUkApCWGMPPpp3IssIyng3jWUYffmctEQY/umD4cY/lZSQyaWgGf/tiG7V1Pf/fuSXLC8t4e+VubpmcR0pCDN+eOIjYqAie0RDlZin0PcTMuGWS72St+Zt6dr/nh2uK+f1HG7l2/ECmjck65rFLR/fnvFF9eXTOurAcprp4237+tXQnMyYPJislvsk2108YyM6ySj5cU9zF1XWtR+esJSUhmpsn5QGQnhTLFafl8I/FOyg+6I3urbZS6HvMN07JJi0xhmc/7bn9nrvKvurHf+DSUcc9bmb81zdOIiYqgvv+vpy6MNrbdc7xX7NWkZkcy21nDWm23Tkj+9InOZa/hPEB3UVb9/Hh2hJumzKEXnHRR5ffMimP6ro6/jwvfP/2jlDoe0z9yVrvre6ZJ2vV1NZx5199/fhPXHsKcdFNzx7at1cc/3nJKL7Ysi+sRrK8uXwXi7eV8m/nDyexhesrREdGcPW4gXy0roTt+yq6sMKu88i768hIiuXGMwYds3xwZhLnjezLC/O3tuvSpOFOoe9BN0wYRExkBH/8bEuoS2mzR+aso2Drfn5++ckMzkxqse2V+TlMHpbBL99aQ+H+nh98ldW1/PKtNYzs34tvnpbTavurTx+AQVgO0523YQ/zNu7le18bQkLM8R9+t501mLLD1bxaoLOTG1Poe1BmcizTx2bx6qLtPWqelg/XFvO7uRu5ZtxApo/NbrW9mfHzy07GAf/x+ooef/D6T/O2ULj/MD+5ZCSRAUzMl5USzzkj+/LKwu1hNUzXOccjc9bRv3cc14wb2GSb0walcerAFJ75dFNYj+JqD4W+R90yOY/K6jpeXNAz9gJ3lR3mhy8vYUS/ZH769eP78ZszIC2Bey8YzsfrSvj74h2dWGHn2nPoCE98sIFzRvThzKEZAT/vuvED2Vtexdsrd3didV1r7roSFm3dz51nD2u2ew9gxpTBbN93mHdWFnVhdd2fQt+jRvTrxeRhGfz58+5/slZNbR13+cfj
P+Efj98W356YS/6gVB6atarHjuj4zXvrqKiu5f6LR7bpeVOGZTIgLZ4X54fHcQ3nHI+8u5aBaQlcmd9yF9d5o/qRm57A0x9v7PHf8oJJoe9hN0/Ko+jAEd5c3r1P1np0zjoWbvH14w9ppR+/KRERxq+uGM3h6lr+8589r5tnfdFB/rpgG9ePH8jQPm37+yMijGvHDWLB5n1hMRnZOyt3s2LHAe4+ZxjRkS3HV2SEccvkwZqaoRGFvoedNSyTod38ZK25a4t5cu5Grhk3IKB+/OYMyUzi++cO452VRcxe3rO6On4+ezWJsVHcfe4J7Xr+Vfk5REdaj+nKa05tnePROesYkpnIN04J7L1wxak5mpqhEYW+h0VE+E7WWrnzAAu64Z6Qbzz+Un8//okdfr0ZkwdzUnYvfjpzRY+ZbfTjdSV8uLaEu84eRlpi+y6DmJ4Uy0Un9efviwt79BDGWct2sq7oED8474SADmQDxMdENpiaoecNUe4MCn2Pu8x/slZ3O229vh+/srq2Xf34TYmKjODhb46htKK6R1xWsLbO8d9vrmZgWgLfbjQWva2unzCIg5U1zFq6K0jVda2a2joem7OOEf2Sufik/m167ldTM2hvHxT6nhcXHcn14wfy/poiNu8pD3U5Rx3tx7+sff34zRmV1Yvbpw7h9S93dPspCl4p2M7aooPcf9EIYqM69qF3em4qJ/RN6rFn6P5j8Q627K3gnvOHt/k60pqa4VgKfeH6iYOIjojgj93kylofrSs52o8faN9tW9xx9lCG9UniP15fzsHK6qC/fjAcOlLDI++u5fTcVC48qV+HX8/MuG78IJYVlrGssDQIFXadIzW1/Pb99YzJ6c25I/u06zXqp2Z44fOe+aEXTAp9oU9yHNPGZvFqQSFfbN7H5j3l7C+vCskMjbvLKvnB0fH4He/Hb0psVCS/umI0uw9U8ou31nTKOjrqd3M3sOdQFT+5ZNTRufI76rJTs4mPjuTF+T3rgO4rC7ezo/Qw95w/vN3bQlMzfKX5yTvEU26dnMcbS3Zw1VOfH11mBr3ioklJiCYlPpreCTGkxEeTmvDV7ZSE+h/f/eS4aJLjooiNimjzf9CG/fj/79rg9OM359SBqdxyZh7PfLqZr4/OYuKQ9E5bV1sV7q/gD59s5htjsxgzIKX1JwSoV1w008dm8caSnfzHJSPpHR/d+pNCrLK6lsc/2MC43DQmDwv8pLSm3HbWYN5dVcSrBYXceEZucArsgRT6AvhO1nrvh2exqaSc/RVVlFZUU3q4mrKKKkoPV7O/wnd7695ySiuqOVBZTUujPKMijKS4KJJifT/J9bfjoo+9HxtFUlwUybFRzNu4ly+27OM33xrb5vHo7XHP+cOZs7qI+/6xjLfvnkJ8TOd9yLTFr99ZiwH/duGIoL/2deMH8dLC7by+uJCbzswL+usH21/mb6X44BEev+aUDn/jaTg1w/UTBgU8AijcKPTlqEHpiQxKTwyobW2d48Bh3wdDqf+DobSiikOVNRw8UsOhyhoO+X/X399zqIoteys4WFnDoSPVVFYffybw1ad3Tj9+U+JjIvnl5aO55g/zeeTdtfykiWmau9qS7aW8sWQnd3xtKNnNzJXfESfn9GZMTm9eXLCNG8/IDVrXUWcoP1LDk3M3MnlYBuMHB+eb2Iwpg/nOXxbz9ordXDK6baOAwoVCX9olMsJITYwhNTEGCOyDorHq2jrKj9T4PwRqqK6t46Ss3sEttBUTh6Rz7fiBPPfZZi4Z3Z9TBqZ26fobqp8rPyMplu9MbX6u/I66bvwg7v37MhZu2c+4vLROW09H/fGzzewrr+Ke84+/Olh7NZya4eKT+3XrD73OogO5EjLRkRGkJMQwIC2Bkf17MTonpc3D8YLh/otG0LdXHPe+towjNbVdvv56b63YTcHW/fzo/BNIamGu/I76+pgskuOi+Es3no+nrKKapz7exLkj+zI2iMc1Gk7NsHDL/qC9bk+i0BfP
S46L5ueXncz64kNM/fVcfvz6cj5cU0xlded/AFTX1jF/015+8dZqHnhjBSP6JXNl/oBOXWd8TCTfPDWHt1bsYs+hI526rvZ65tNNHKys4YfntW/qiZbUT83w9Mcbg/7aPYG6d0SAr43ow++uO5V/LtnB61/u4MUF24iPjmTSsAzOHdmHs0f0JTM5Nijr2l1WyUfrivlwTQmfbdjDwSM1REca+YPS+Mmlgc2V31HXTxjIn+Zt4dWCQm7vxK6k9th76AjPferrbhuV1Svor18/NcNv31/PhuJDXTJooDsJKPTN7ELgt0Ak8Ixz7pfNtLsCeBU43TlX4F82GngK6AXU+R/TaXHS7Vx0cn8uOrk/ldW1zN+0l/dXF/P+6iLmrCrCbDljclI4d2QfzhnZlxH9kgPuD66prWPxtlI+XFvM3LUlrN51AID+veO4dEx/pg73zZHfmV06jQ3tk8z4vDT++sVWbpsyOCTdas156uNNHK6u5QfnDuu0dXx74iB+/9FGnv10E7+4fHSnrac7stZmVzSzSGAdcB5QCCwErnHOrWrULhl4E4gB7nDOFZhZFLAYuME5t9TM0oFS51yz35vz8/NdQUFBR/4mkaBxzrF610HeX13Ee2uKWbrddzZrdko85/g/ACYMTjtumoTiA5XMXVfC3LXFfLJ+Dwcra4iKMPJzU5k6vA9Th2cyvG/gHxyd4V9Ld3Ln377kT//ndKYOb9+ZrsFWfKCSyQ9/yCWj+/PoVWM7dV0/fn05ry4q5LN/Pzto3+JCycwWOefyW2sXyK7FOGCDc26T/4VfAqYDjWesegh4GPhRg2XnA8ucc0sBnHN7A1ifSLdhZozK6sWorF7cec4wig9U8sGaYt5bXcwrBdv58+dbSYyJZMoJmZx1Qibb91cwd20JK3f69ub79orl4pP687URmZw5NIPkuO5zQtQFJ/YjIymGFxds6zah/8SHG6itc9x9Tuft5de7ZVIef/1iG3/+fEtQRwh1d4GEfjawvcH9QmB8wwZmdgowwDk3y8wahv4JgDOzd4BM4CXn3MMdrFkkZPr0iuPqcQO5etxAKqtrmbdxD++tLuaD1cW8tWI3kRHGaYNSuffC4Uw9oQ8j+4d2b74lMVERXJU/gN9/tJEPbsECAAAJnklEQVSdpYfJ6oTzAtqicH8Ff/1iG1fmDwj4fJGOaDg1w+1Tm77AejgK5K9s6h17tE/IzCKAx4Cbmnn9ScDpQAXwvv8ryPvHrMBsBjADYODApi90LNLdxEVHcvaIvpw9oi/uG46NJYfITI7rEdMb1Ltm3EB+99FGXlq4vVNGygSqrKKaX7y1BsO48+yhXbbeYE3NUFldy9LtpTggLyORPsmx3fbDPpDQLwQajiHLARpeXy8ZOAmY6/8j+wEzzWya/7kfOef2AJjZbOBU4JjQd849DTwNvj79dv0lIiFkZgztkxzqMtpsQFoCU0/I5KUvtnHn2UNbvQRhMB2uquX9NUW8sWQnc9cWU13ruH3qkC79xtHeqRmqaupYWljKvA17+XzTHhZvKz3mWtMJMZEMSk8kLyOB3PREcjMSyctIJDc9kYykmJB+IAQS+guBYWaWB+wArgaurX/QOVcGHJ0JyczmAj/yH8jdCNxrZglAFXAWvm8FItJNXDd+ELf+uYD3VxdxYRsvUNJW1bV1fLphD/9aspN3Vu6mvKqWvr1iuXFiLtPGZnFydteekQ1fTc3wzsrdXHxy039/TW0dy3aU8fnGvczftJeCLfs5XF2LGYzs14sbJgxi4uB0YqIi2LK3nM17ytmyp5zVuw7y7soiahrMWJscG8Ug/4dB/QdBbkYigzMS/We4d65WQ985V2NmdwDv4Buy+ZxzbqWZPQgUOOdmtvDc/Wb2KL4PDgfMds69GaTaRSQIvjaiD1m94/jL/G2dEvp1dY7F2/bzxpKdvLl8F/vKq+gVF8XXx2QxbWwW4/PSQzr5Wf3UDE99vImLTvJNzVBb51i50xfyn2/a
y8LN+yiv8g06HN43mW+dPoAJg9OZMDiNlIRjg3oKmcfcr66tY8f+w2ze6/sg2LKnnM17K1hWWMbs5btoOIP5JaP788S1p3bq39vqkM2upiGbIl3v8ffX88icdbz9/clBG0q6ZvcB3liyk5lLdrKj9DBx0RGcO7Iv08ZkcdbwzA5fDSyYXpi/lf/85wpmTBnMppJDLNi8j4OVvnn3B2cmcsaQdCYOzmD84DQykoI3vLOqpo7t+yt8HwR7yslKiW/220ZrAh2yqdAXEYoPVnLGLz6gps4R5Z9MLz0xhtSEGNISfT9HlzV4LD0phpSE6KMBvn1fBTOX+oJ+bdFBIiOMycMymD42i/NG9evSE9Da4nBVLZMf/oA9h6oYlJ7AxMHpTBySzoTB6fTtFRfq8gKi0BeRNvl8415W7ixjX3kV+yuq2HvI/7u8iv3lvumzm4uLZP81EnaW+U62zx+UyvSxWVx8cn/Sg7hn3Jl2l1VS51zIh662VzBPzhIRD5g4JL3FK4jV1NZRdriafeVVX/1UVLHvkO93aUU1Q/skMW1MFgPSErqw8uDo17tn7NF3lEJfRAISFRlBelJsj9lzl6ZpamUREQ9R6IuIeIhCX0TEQxT6IiIeotAXEfEQhb6IiIco9EVEPEShLyLiId1uGgYzKwG2duAlMoA9QSqnM6i+jlF9HaP6OqY71zfIOZfZWqNuF/odZWYFgcw/ESqqr2NUX8eovo7p7vUFQt07IiIeotAXEfGQcAz9p0NdQCtUX8eovo5RfR3T3etrVdj16YuISPPCcU9fRESa0SND38wuNLO1ZrbBzO5r4vFYM3vZ//gCM8vtwtoGmNmHZrbazFaa2d1NtJlqZmVmtsT/80BX1deghi1mtty//uMuVWY+/+vfhsvMrHOv1vzVeoc32C5LzOyAmX2/UZsu335m9pyZFZvZigbL0sxsjpmt9/9Obea5N/rbrDezG7uwvl+b2Rr/v9/rZpbSzHNbfC90Yn0/M7MdDf4dL27muS3+f+/E+l5uUNsWM1vSzHM7ffsFlXOuR/0AkcBGYDAQAywFRjVq813g9/7bVwMvd2F9/YFT/beTgXVN1DcVmBXi7bgFyGjh8YuBtwADJgALQvRvvRvf+OOQbj9gCnAqsKLBsoeB+/y37wN+1cTz0oBN/t+p/tupXVTf+UCU//avmqovkPdCJ9b3M+BHAbwHWvz/3ln1NXr8EeCBUG2/YP70xD39ccAG59wm51wV8BIwvVGb6cDz/tuvAeeYmXVFcc65Xc65xf7bB4HVQHZXrDvIpgN/dj7zgRQz69/FNZwDbHTOdeRkvaBwzn0M7Gu0uOH77HngG0089QJgjnNun3NuPzAHuLAr6nPOveucq/HfnQ/kBHu9gWpm+wUikP/vHdZSff7suAr4W7DXGwo9MfSzge0N7hdyfKgebeN/05cBzV/8s5P4u5VOARY08fBEM1tqZm+Z2YldWpiPA941s0VmNqOJxwPZzp3tapr/jxbq7QfQ1zm3C3wf9kCfJtp0h+0IcDO+b25Nae290Jnu8Hc/PddM91h32H6TgSLn3PpmHg/l9muznhj6Te2xNx6CFEibTmVmScDfge875w40engxvi6LMcDjwD+7sja/M51zpwIXAd8zsymNHg/pNjSzGGAa8GoTD3eH7Reo7vBe/DFQA7zYTJPW3gud5XfAEGAssAtfF0pjId9+wDW0vJcfqu3XLj0x9AuBAQ3u5wA7m2tjZlFAb9r31bJdzCwaX+C/6Jz7R+PHnXMHnHOH/LdnA9FmltFV9fnXu9P/uxh4Hd/X6IYC2c6d6SJgsXOuqPED3WH7+RXVd3n5fxc30Sak29F/4PhS4Drn74BuLID3QqdwzhU552qdc3XAH5pZb6i3XxRwOfByc21Ctf3aqyeG/kJgmJnl+fcGrwZmNmozE6gfJXEF8EFzb/hg8/f/PQusds492kybfvXHGMxsHL5/h71dUZ9/nYlmllx/G98BvxWNms0E
vu0fxTMBKKvvyugize5dhXr7NdDwfXYj8EYTbd4BzjezVH/3xfn+ZZ3OzC4E/h2Y5pyraKZNIO+Fzqqv4TGiy5pZbyD/3zvTucAa51xhUw+Gcvu1W6iPJLfnB9/IknX4jur/2L/sQXxvboA4fN0CG4AvgMFdWNskfF8/lwFL/D8XA98BvuNvcwewEt9IhPnAGV28/Qb7173UX0f9NmxYowFP+LfxciC/C+tLwBfivRssC+n2w/cBtAuoxrf3eQu+40TvA+v9v9P8bfOBZxo892b/e3ED8H+6sL4N+PrD69+H9SPasoDZLb0Xuqi+F/zvrWX4grx/4/r894/7/94V9fmX/6n+fdegbZdvv2D+6IxcEREP6YndOyIi0k4KfRERD1Hoi4h4iEJfRMRDFPoiIh6i0BcR8RCFvoiIhyj0RUQ85P8DSHaaUggYbqEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x19f2e528710>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def train_rnn(num_epochs, num_steps, state_size=None, verbose=True):\n",
    "    \"\"\"Train the RNN graph built above and return windowed average losses.\n",
    "\n",
    "    Within an epoch, each batch's final RNN state is fed back as `init_state`\n",
    "    for the next batch (truncated BPTT); the state restarts from zeros at the\n",
    "    start of every epoch.\n",
    "\n",
    "    Args:\n",
    "        num_epochs: number of freshly generated datasets to train on.\n",
    "        num_steps: time steps per batch; must match the `x`/`y` placeholders.\n",
    "        state_size: RNN state width. Defaults to the width of `init_state`;\n",
    "            the previous hard-coded default of 4 could not be fed into the graph.\n",
    "        verbose: print the epoch index and the average loss every 100 steps.\n",
    "\n",
    "    Returns:\n",
    "        List of average losses, one entry per 100-step reporting window.\n",
    "    \"\"\"\n",
    "    if state_size is None:\n",
    "        state_size = init_state.shape.as_list()[1]\n",
    "    report_every = 100\n",
    "    training_losses = []\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n",
    "            window_loss, window_steps = 0.0, 0\n",
    "            training_state = np.zeros((batch_size, state_size))\n",
    "            if verbose:\n",
    "                print('\\nepoch', idx)\n",
    "            for step, (X, Y) in enumerate(epoch):\n",
    "                # Fetch only what is used; the per-example `losses` tensor is not needed here.\n",
    "                training_loss_, training_state, _ = sess.run(\n",
    "                    [total_loss, final_state, train_step],\n",
    "                    feed_dict={x: X, y: Y, init_state: training_state})\n",
    "                window_loss += training_loss_\n",
    "                window_steps += 1\n",
    "                if step % report_every == 0 and step > 0:\n",
    "                    # Divide by the number of steps actually accumulated: the first\n",
    "                    # window of each epoch covers steps 0..100, i.e. 101 batches.\n",
    "                    avg_loss = window_loss / window_steps\n",
    "                    if verbose:\n",
    "                        print('第 {0} 步的平均损失 {1}'.format(step, avg_loss))\n",
    "                    training_losses.append(avg_loss)\n",
    "                    window_loss, window_steps = 0.0, 0\n",
    "    return training_losses\n",
    "\n",
    "\n",
    "training_losses = train_rnn(num_epochs=5, num_steps=num_steps, state_size=state_size)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(training_losses)\n",
    "ax.set(title='Training loss (average per 100-step window)',\n",
    "       xlabel='reporting window', ylabel='cross-entropy')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
